{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9aHTrpANhSoC"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WAnqNAzfcVu0"
      },
      "source": [
        "# Emotion prediction with GoEmotions and PRADO\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OcNoWgG7hvIs"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/research/seq_flow_lite/demo/colab/emotion_colab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/research/seq_flow_lite/demo/colab/emotion_colab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dhekoIrWiSsv"
      },
      "source": [
        "In this tutorial, we will work through training a neural emotion prediction model using the tensorflow-models pip package and Bazel.\n",
        "\n",
        "This tutorial uses GoEmotions, an emotion prediction dataset available on [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/goemotions). We will train a sequence projection model architecture named PRADO, available in the [TensorFlow Model Garden](https://github.com/tensorflow/models/blob/master/research/seq_flow_lite/models/prado.py). Finally, we will examine an application of emotion prediction to emoji suggestions from text."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "grmac7ZYj02a"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4aGnloeD1Mfo"
      },
      "source": [
        "### Install TensorFlow 2.11.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MqP5qpAR1W4f"
      },
      "source": [
        "The `seq_flow_lite` library was written with the assumption that TensorFlow 2.11.0 will be used. You may need to restart the runtime after installing this version of TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hzuq_GVn1nXO"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow==2.11.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D_mi4NZeeB1l"
      },
      "source": [
        "### Install the TensorFlow Datasets pip package"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tCk46-HdmIyD"
      },
      "source": [
        "`tensorflow_datasets` is a collection of datasets that includes the GoEmotions dataset. We install it with pip."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mvO0_HcKx0_V"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow_datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p2wqyg-7mbfV"
      },
      "source": [
        "### Install the Sequence Projection Models package"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1JRZS_aSeINK"
      },
      "source": [
        "Install Bazel: This will allow us to build custom TensorFlow ops used by the PRADO architecture."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N00X4P229Ppm"
      },
      "outputs": [],
      "source": [
        "!sudo apt install curl gnupg\n",
        "!curl https://bazel.build/bazel-release.pub.gpg | sudo apt-key add -\n",
        "!echo \"deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8\" | sudo tee /etc/apt/sources.list.d/bazel.list\n",
        "!sudo apt update\n",
        "!sudo apt install bazel=5.4.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9JeSDpZFelL5"
      },
      "source": [
        "Install the library:\n",
        "* `seq_flow_lite` includes the PRADO architecture and custom ops.\n",
        "* We download the code from GitHub, and then build and install the TF and TFLite ops used by the model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mktlCYcd9iLG"
      },
      "outputs": [],
      "source": [
        "!git clone https://www.github.com/tensorflow/models\n",
        "!models/research/seq_flow_lite/demo/colab/setup_workspace.sh\n",
        "!pip install models/research/seq_flow_lite\n",
        "!rm -rf models/research/seq_flow_lite/tf_ops\n",
        "!rm -rf models/research/seq_flow_lite/tflite_ops"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rP8iKa4Il4mL"
      },
      "source": [
        "## Training an Emotion Prediction Model\n",
        "\n",
        "* First, we load the GoEmotions data from TFDS.\n",
        "* Next, we prepare the PRADO model for training. We set up the model configuration, including hyperparameters and labels. We also prepare the dataset, which involves projecting the inputs from the dataset and passing the projections to the model. This is needed because a model training on a TPU cannot handle string inputs.\n",
        "* Finally, we train and evaluate the model and produce model-level and per-label metrics."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YtvR40K8K0Bn"
      },
      "source": [
        "***Start here on Runtime reset***, once the packages above are properly installed:\n",
        "* Go to the `seq_flow_lite` directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ImEejssVKvxR"
      },
      "outputs": [],
      "source": [
        "%cd models/research/seq_flow_lite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uwSPqHXAeQ6H"
      },
      "source": [
        "* Import the TensorFlow and TensorFlow Datasets libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kc4y4n80eL_b"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j-CtG3cagPgl"
      },
      "source": [
        "### The data: GoEmotions\n",
        "In this tutorial, we use the [GoEmotions dataset from TFDS](https://www.tensorflow.org/datasets/catalog/goemotions).\n",
        "\n",
        "GoEmotions is a corpus of comments extracted from Reddit, with human annotations for 27 emotion categories or Neutral.\n",
        "\n",
        "*   Number of labels: 27 emotion categories, plus Neutral.\n",
        "*   Size of training dataset: 43,410.\n",
        "*   Size of evaluation dataset: 5,427.\n",
        "*   Maximum sequence length in training and evaluation datasets: 30.\n",
        "\n",
        "The emotion categories are admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, joy, love, nervousness, optimism, pride, realization, relief, remorse, sadness, surprise.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bvsn_s3S0SAt"
      },
      "source": [
        "Load the data from TFDS:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KtTLwtEqwcR2"
      },
      "outputs": [],
      "source": [
        "ds = tfds.load('goemotions', split='train')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gJuu4jKet9zq"
      },
      "source": [
        "Print 5 sample data elements from the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y0O18rSLuDx5"
      },
      "outputs": [],
      "source": [
        "for element in ds.take(5):\n",
        "  print(element)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UAz-tdQfuVBn"
      },
      "source": [
        "### The model: PRADO\n",
        "\n",
        "We train an Emotion Prediction model, based on the [PRADO architecture](https://github.com/tensorflow/models/blob/master/research/seq_flow_lite/models/prado.py) from the [Sequence Projection Models package](https://github.com/tensorflow/models/tree/master/research/seq_flow_lite).\n",
        "\n",
        "PRADO projects input sequences to fixed-size features. The idea behind this approach is to build embedding-free models that minimize the model size. Instead of using an embedding table to look up embeddings, sequence projection models compute them on the fly, resulting in space-efficient models.\n",
        "\n",
        "In this section, we prepare the PRADO model for training.\n",
        "\n",
        "The GoEmotions dataset cannot be fed directly into the PRADO model, so below we also handle the necessary preprocessing by providing a dataset builder."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e9uPSZYpgBqP"
      },
      "source": [
        "Prepare the model configuration:\n",
        "* Enumerate the labels expected to be found in the GoEmotions dataset.\n",
        "* Prepare the `CONFIG` dictionary, which holds the training parameters, and the `MODEL_CONFIG` dictionary, which holds the model hyperparameters. See sample configs for the PRADO model [here](https://github.com/tensorflow/models/tree/master/research/seq_flow_lite/configs)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DkQMnTcLyFeR"
      },
      "outputs": [],
      "source": [
        "LABELS = [\n",
        "    'admiration',\n",
        "    'amusement',\n",
        "    'anger',\n",
        "    'annoyance',\n",
        "    'approval',\n",
        "    'caring',\n",
        "    'confusion',\n",
        "    'curiosity',\n",
        "    'desire',\n",
        "    'disappointment',\n",
        "    'disapproval',\n",
        "    'disgust',\n",
        "    'embarrassment',\n",
        "    'excitement',\n",
        "    'fear',\n",
        "    'gratitude',\n",
        "    'grief',\n",
        "    'joy',\n",
        "    'love',\n",
        "    'nervousness',\n",
        "    'optimism',\n",
        "    'pride',\n",
        "    'realization',\n",
        "    'relief',\n",
        "    'remorse',\n",
        "    'sadness',\n",
        "    'surprise',\n",
        "    'neutral',\n",
        "]\n",
        "\n",
        "# Model training parameters.\n",
        "CONFIG = {\n",
        "    'name': 'models.prado',\n",
        "    'batch_size': 1024,\n",
        "    'train_steps': 10000,\n",
        "    'learning_rate': 0.0006,\n",
        "    'learning_rate_decay_steps': 340,\n",
        "    'learning_rate_decay_rate': 0.7,\n",
        "}\n",
        "\n",
        "# Limits the amount of logging output produced by the training run, in order to\n",
        "# avoid browser slowdowns.\n",
        "CONFIG['save_checkpoints_steps'] = int(CONFIG['train_steps'] / 10)\n",
        "\n",
        "MODEL_CONFIG = {\n",
        "    'labels': LABELS,\n",
        "    'multilabel': True,\n",
        "    'quantize': False,\n",
        "    'max_seq_len': 128,\n",
        "    'max_seq_len_inference': 128,\n",
        "    'exclude_nonalphaspace_unicodes': False,\n",
        "    'split_on_space': True,\n",
        "    'embedding_regularizer_scale': 0.035,\n",
        "    'embedding_size': 64,\n",
        "    'bigram_channels': 64,\n",
        "    'trigram_channels': 64,\n",
        "    'feature_size': 512,\n",
        "    'network_regularizer_scale': 0.0001,\n",
        "    'keep_prob': 0.5,\n",
        "    'word_novelty_bits': 0,\n",
        "    'doc_size_levels': 0,\n",
        "    'add_bos_tag': False,\n",
        "    'add_eos_tag': False,\n",
        "    'pre_logits_fc_layers': [],\n",
        "    'text_distortion_probability': 0.0,\n",
        "}\n",
        "\n",
        "CONFIG['model_config'] = MODEL_CONFIG"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R-pUW649gfzA"
      },
      "source": [
        "Write a function that builds the datasets for the model.  It will load the data, handle batching, and generate projections for the input text."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "unYlUYXq119f"
      },
      "outputs": [],
      "source": [
        "from layers import base_layers\n",
        "from layers import projection_layers\n",
        "\n",
        "def build_dataset(mode, inspect=False):\n",
        "  if mode == base_layers.TRAIN:\n",
        "    split = 'train'\n",
        "    count = None\n",
        "  elif mode == base_layers.EVAL:\n",
        "    split = 'test'\n",
        "    count = 1\n",
        "  else:\n",
        "    raise ValueError('mode={}, must be TRAIN or EVAL'.format(mode))\n",
        "\n",
        "  batch_size = CONFIG['batch_size']\n",
        "  if inspect:\n",
        "    batch_size = 1\n",
        "\n",
        "  # Convert examples from their dataset format into the model format.\n",
        "  def process_input(features):\n",
        "    # Generate the projection for each comment_text input.  The final tensor \n",
        "    # will have the shape [batch_size, number of tokens, feature size].\n",
        "    # Additionally, we generate a tensor containing the number of tokens for\n",
        "    # each comment_text (seq_length).  This is needed because the projection\n",
        "    # tensor is a full tensor, and we are not using EOS tokens.\n",
        "    text = features['comment_text']\n",
        "    text = tf.reshape(text, [batch_size])\n",
        "    projection_layer = projection_layers.ProjectionLayer(MODEL_CONFIG, mode)\n",
        "    projection, seq_length = projection_layer(text)\n",
        "\n",
        "    # Convert the labels into an indicator tensor, using the LABELS indices.\n",
        "    label = tf.stack([features[label] for label in LABELS], axis=-1)\n",
        "    label = tf.cast(label, tf.float32)\n",
        "    label = tf.reshape(label, [batch_size, len(LABELS)])\n",
        "\n",
        "    model_features = ({'projection': projection, 'sequence_length': seq_length}, label)\n",
        "\n",
        "    if inspect:\n",
        "      model_features = (model_features[0], model_features[1], features)\n",
        "\n",
        "    return model_features\n",
        "\n",
        "  ds = tfds.load('goemotions', split=split)\n",
        "  ds = ds.repeat(count=count)\n",
        "  ds = ds.shuffle(buffer_size=batch_size * 2)\n",
        "  ds = ds.batch(batch_size, drop_remainder=True)\n",
        "  ds = ds.map(process_input,\n",
        "              num_parallel_calls=tf.data.experimental.AUTOTUNE,\n",
        "              deterministic=False)\n",
        "  ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
        "  return ds\n",
        "\n",
        "train_dataset = build_dataset(base_layers.TRAIN)\n",
        "test_dataset = build_dataset(base_layers.EVAL)\n",
        "inspect_dataset = build_dataset(base_layers.TRAIN, inspect=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DQmYWg6ivCHS"
      },
      "source": [
        "Print a batch of examples in model format. Each batch consists of:\n",
        "* the input tensors (`projection` and `sequence_length`; first tuple value)\n",
        "* the label tensor (second tuple value)\n",
        "\n",
        "The projection tensor is a **[batch size, max_seq_length, feature_size]** floating point tensor. The **[b, i]** vector is the feature vector of the **i**th token of the **b**th comment_text. The rest of the tensor is zero-padded, and the `sequence_length` tensor indicates the number of feature vectors for each comment_text.\n",
        "\n",
        "The label tensor is an indicator tensor of the set of true labels for each example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1OyK7rjTvBjF"
      },
      "outputs": [],
      "source": [
        "example = next(iter(train_dataset))\n",
        "print(\"inputs = {}\".format(example[0]))\n",
        "print(\"labels = {}\".format(example[1]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ytMQHT5Kd7A_"
      },
      "source": [
        "In the `inspect_dataset` version of the dataset, the original example is added as the third element of the tuple."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "29EzRoCfI91r"
      },
      "outputs": [],
      "source": [
        "example = next(iter(inspect_dataset))\n",
        "print(\"inputs = {}\".format(example[0]))\n",
        "print(\"labels = {}\".format(example[1]))\n",
        "print(\"original example = {}\".format(example[2]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CLDbHTIvvX11"
      },
      "source": [
        "### Train and Evaluate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QqUTa7wXsHoO"
      },
      "source": [
        "First, we define a function to build the model. The model inputs vary with the mode. For training and evaluation, the model takes the projection and sequence length as inputs. Otherwise, it takes strings as inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "erEiNX3ToLZ1"
      },
      "outputs": [],
      "source": [
        "from models import prado\n",
        "\n",
        "def build_model(mode):\n",
        "  # First we define our inputs.\n",
        "  inputs = []\n",
        "  if mode == base_layers.TRAIN or mode == base_layers.EVAL:\n",
        "    # For TRAIN and EVAL, we'll be getting dataset examples,\n",
        "    # so we'll get projections and sequence_lengths.\n",
        "    projection = tf.keras.Input(\n",
        "        shape=(MODEL_CONFIG['max_seq_len'], MODEL_CONFIG['feature_size']),\n",
        "        name='projection',\n",
        "        dtype='float32')\n",
        "\n",
        "    sequence_length = tf.keras.Input(\n",
        "        shape=(), name='sequence_length', dtype='float32')\n",
        "    inputs = [projection, sequence_length]\n",
        "  else:\n",
        "    # Otherwise, we get string inputs which we need to project.\n",
        "    input = tf.keras.Input(shape=(), name='input', dtype='string')\n",
        "    projection_layer = projection_layers.ProjectionLayer(MODEL_CONFIG, mode)\n",
        "    projection, sequence_length = projection_layer(input)\n",
        "    inputs = [input]\n",
        "\n",
        "  # Next we add the model layer.\n",
        "  model_layer = prado.Encoder(MODEL_CONFIG, mode)\n",
        "  logits = model_layer(projection, sequence_length)\n",
        "\n",
        "  # Finally we add an activation layer.\n",
        "  if MODEL_CONFIG['multilabel']:\n",
        "    activation = tf.keras.layers.Activation('sigmoid', name='predictions')\n",
        "  else:\n",
        "    activation = tf.keras.layers.Activation('softmax', name='predictions')\n",
        "  predictions = activation(logits)\n",
        "\n",
        "  model = tf.keras.Model(\n",
        "      inputs=inputs,\n",
        "      outputs=[predictions])\n",
        "  \n",
        "  return model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "caHpK9Htv40g"
      },
      "source": [
        "Train the model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2xM-2R38kogo"
      },
      "outputs": [],
      "source": [
        "# Remove any previous training data.\n",
        "!rm -rf model\n",
        "\n",
        "model = build_model(base_layers.TRAIN)\n",
        "\n",
        "# Create the optimizer.\n",
        "learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(\n",
        "    initial_learning_rate=CONFIG['learning_rate'],\n",
        "    decay_rate=CONFIG['learning_rate_decay_rate'],\n",
        "    decay_steps=CONFIG['learning_rate_decay_steps'],\n",
        "    staircase=True)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
        "\n",
        "# Define the loss function.\n",
        "loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)\n",
        "\n",
        "model.compile(optimizer=optimizer, loss=loss)\n",
        "\n",
        "epochs = int(CONFIG['train_steps'] / CONFIG['save_checkpoints_steps'])\n",
        "model.fit(\n",
        "    x=train_dataset,\n",
        "    epochs=epochs,\n",
        "    validation_data=test_dataset,\n",
        "    steps_per_epoch=CONFIG['save_checkpoints_steps'])\n",
        "\n",
        "model.save_weights('model/model_checkpoint')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0hdbXBs0g3oX"
      },
      "source": [
        "Load a training checkpoint and evaluate:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A1qc9GNtF3s5"
      },
      "outputs": [],
      "source": [
        "model = build_model(base_layers.EVAL)\n",
        "\n",
        "# Define metrics over each category.\n",
        "metrics = []\n",
        "for i, label in enumerate(LABELS):\n",
        "  metric = tf.keras.metrics.Precision(\n",
        "      thresholds=[0.5],\n",
        "      class_id=i,\n",
        "      name='precision@0.5/{}'.format(label))\n",
        "  metrics.append(metric)\n",
        "  metric = tf.keras.metrics.Recall(\n",
        "      thresholds=[0.5],\n",
        "      class_id=i,\n",
        "      name='recall@0.5/{}'.format(label))\n",
        "  metrics.append(metric)\n",
        "\n",
        "# Define metrics over the entire task.\n",
        "metric = tf.keras.metrics.Precision(thresholds=[0.5], name='precision@0.5/all')\n",
        "metrics.append(metric)\n",
        "metric = tf.keras.metrics.Recall(thresholds=[0.5], name='recall@0.5/all')\n",
        "metrics.append(metric)\n",
        "\n",
        "model.compile(metrics=metrics)\n",
        "model.load_weights('model/model_checkpoint')\n",
        "result = model.evaluate(x=test_dataset, return_dict=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Namwa3enwQBc"
      },
      "source": [
        "Print the evaluation metrics for each emotion label, followed by the metrics over all labels:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l420PosisfXN"
      },
      "outputs": [],
      "source": [
        "for label in LABELS:\n",
        "  precision_key = 'precision@0.5/{}'.format(label)\n",
        "  recall_key = 'recall@0.5/{}'.format(label)\n",
        "  if precision_key in result and recall_key in result:\n",
        "    print('{}: (precision@0.5: {}, recall@0.5: {})'.format(\n",
        "        label, result[precision_key], result[recall_key]))\n",
        "    \n",
        "precision_key = 'precision@0.5/all'\n",
        "recall_key = 'recall@0.5/all'\n",
        "if precision_key in result and recall_key in result:\n",
        "  print('all: (precision@0.5: {}, recall@0.5: {})'.format(\n",
        "      result[precision_key], result[recall_key]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AZSWnwTMqZ5f"
      },
      "source": [
        "## Suggest Emojis using an Emotion Prediction model\n",
        "\n",
        "In this section, we apply the Emotion Prediction model trained above to suggest emojis relevant to input text.\n",
        "\n",
        "Refer to our [GoEmotions Model Card](https://github.com/google-research/google-research/blob/master/goemotions/goemotions_model_card.pdf) for additional uses of the model and considerations and limitations for using the GoEmotions data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aybpGQV1qr8I"
      },
      "source": [
        "Map each emotion label to a relevant emoji:\n",
        "* Emotions are subtle and multi-faceted. In many cases, no one emoji can truly capture the full complexity of the human experience behind each emotion.\n",
        "* For the purpose of this exercise, we will select an emoji that captures at least one facet that is conveyed by an emotion label."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lgs12b90qmSQ"
      },
      "outputs": [],
      "source": [
        "EMOJI_MAP = {\n",
        "    'admiration': '👏',\n",
        "    'amusement': '😂',\n",
        "    'anger': '😡',\n",
        "    'annoyance': '😒',\n",
        "    'approval': '👍',\n",
        "    'caring': '🤗',\n",
        "    'confusion': '😕',\n",
        "    'curiosity': '🤔',\n",
        "    'desire': '😍',\n",
        "    'disappointment': '😞',\n",
        "    'disapproval': '👎',\n",
        "    'disgust': '🤮',\n",
        "    'embarrassment': '😳',\n",
        "    'excitement': '🤩',\n",
        "    'fear': '😨',\n",
        "    'gratitude': '🙏',\n",
        "    'grief': '😢',\n",
        "    'joy': '😃',\n",
        "    'love': '❤️',\n",
        "    'nervousness': '😬',\n",
        "    'optimism': '🤞',\n",
        "    'pride': '😌',\n",
        "    'realization': '💡',\n",
        "    'relief': '😅',\n",
        "    'remorse': '',\n",
        "    'sadness': '😞',\n",
        "    'surprise': '😲',\n",
        "    'neutral': '',\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rh_3y7OL7JG_"
      },
      "source": [
        "Select sample inputs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rdD6xPpn7Mjm"
      },
      "outputs": [],
      "source": [
        "PREDICT_TEXT = [\n",
        "  b'Good for you!',\n",
        "  b'Happy birthday!',\n",
        "  b'I love you.',\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vavivya6hGw0"
      },
      "source": [
        "Run inference for the selected examples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tJ6iyLlLo5-3"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Rebuild the model in PREDICT mode so that it accepts raw strings.\n",
        "model = build_model(base_layers.PREDICT)\n",
        "model.load_weights('model/model_checkpoint')\n",
        "\n",
        "for text in PREDICT_TEXT:\n",
        "  results = model.predict(x=[text])\n",
        "  print('')\n",
        "  print('{}:'.format(text))\n",
        "  # Sort label indices by descending score and show the top 3.\n",
        "  labels = np.flip(np.argsort(results[0]))\n",
        "  for x in range(3):\n",
        "    label = LABELS[labels[x]]\n",
        "    # Fall back to the label name when no emoji is mapped.\n",
        "    label = EMOJI_MAP[label] if EMOJI_MAP[label] else label\n",
        "    print('{}: {}'.format(label, results[0][labels[x]]))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Emotion prediction with GoEmotions and PRADO",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
